from invoke import task
from storage import create_container
from config import load_config
import logging


def upload_data_from_to(
    c, remote_path, local_path, container_name, account_name, account_key
):
    cmd = (
        f"azcopy --source {local_path} --destination  https://{account_name}.blob.core.windows.net/{container_name}/{remote_path} "
        f"--dest-key {account_key} --quiet --recursive --exclude-older"
    )
    c.run(cmd, pty=True)


@task(pre=[create_container])
def upload_training_data(c):
    """Upload training data to container specified in .env file
    """
    env_values = load_config()
    container_name = env_values.get("CONTAINER_NAME")
    account_name = env_values.get("ACCOUNT_NAME")
    account_key = env_values.get("ACCOUNT_KEY")
    upload_data_from_to(
        c, "train", "/data/train", container_name, account_name, account_key
    )


@task(pre=[create_container])
def upload_validation_data(c):
    """Upload validation data to container specified in .env file
    """
    env_values = load_config()
    container_name = env_values.get("CONTAINER_NAME")
    account_name = env_values.get("ACCOUNT_NAME")
    account_key = env_values.get("ACCOUNT_KEY")
    upload_data_from_to(
        c, "validation", "/data/validation", container_name, account_name, account_key
    )


@task(pre=[upload_training_data, upload_validation_data])
def upload_data(c):
    """Upload training and validation data to container specified in .env file
    """
    print("Data uploaded")


def download_data_from_to(
    c, remote_path, local_path, container_name, account_name, account_key
):
    cmd = (
        f"azcopy --source https://{account_name}.blob.core.windows.net/{container_name}/{remote_path}  --destination {local_path} "
        f"--source-key {account_key} --quiet --recursive --exclude-older"
    )
    c.run(cmd, pty=True)


@task
def download_training(c):
    """Download training data from blob container specified in .env file
    """
    env_values = load_config()
    container_name = env_values.get("CONTAINER_NAME")
    account_name = env_values.get("ACCOUNT_NAME")
    account_key = env_values.get("ACCOUNT_KEY")
    download_data_from_to(
        c, "train", "/data/train", container_name, account_name, account_key
    )


@task
def download_validation(c):
    """Download validation data from blob container specified in .env file
    """
    env_values = load_config()
    container_name = env_values.get("CONTAINER_NAME")
    account_name = env_values.get("ACCOUNT_NAME")
    account_key = env_values.get("ACCOUNT_KEY")
    download_data_from_to(
        c, "validation", "/data/validation", container_name, account_name, account_key
    )


@task(pre=[download_training, download_validation])
def download_data(c):
    """Download training and validation data from blob container specified in .env file
    """
    print("Data downloaded")


@task
def prepare_imagenet(c, download_dir="/data", target_dir="/data"):
    """Prepare imagenet data found in download_dir and push results to target_dir
    
    Args:
        download_dir (str, optional): Location where imagenet tar file should be found. Defaults to "/data".
        target_dir (str, optional): Location where to copy uncompressed imagenet data to. Defaults to "/data".
    """
    from prepare_imagenet import main as prepare_imagenet_data
    logger = logging.getLogger(__name__)
    prepare_imagenet_data(download_dir, target_dir, checksum=False)
    logger.info("Data preparation complete")